{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用gensim训练word2vec\n",
    "\n",
    "本DEMO只使用部分数据，使用全部数据预训练的词向量地址：  \n",
    "\n",
    "链接: https://pan.baidu.com/s/1ewlck3zwXVQuAzraZ26Euw 提取码: qbpr "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fab09a064b0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import logging\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Timestamped INFO-level logging for the rest of the notebook\n",
    "logging.basicConfig(format='%(asctime)-15s %(levelname)s: %(message)s', level=logging.INFO)\n",
    "\n",
    "# Seed every RNG in use (Python, NumPy, PyTorch) with one shared value\n",
    "seed = 666\n",
    "for seed_fn in (random.seed, np.random.seed, torch.cuda.manual_seed):\n",
    "    seed_fn(seed)\n",
    "# Kept as the final expression: its return value is what the cell displays\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-07-18 23:30:04,912 INFO: Fold lens [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]\n"
     ]
    }
   ],
   "source": [
    "# split data to 10 fold\n",
    "# fold_num: number of folds for the split\n",
    "fold_num = 10\n",
    "# data_file: path of the CSV that all_data2fold reads (tab-separated, with `text` and `label` columns)\n",
    "data_file = '../data/train_set.csv'\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def all_data2fold(fold_num, num=10000):\n",
    "    \"\"\"Split the first `num` rows of `data_file` into label-stratified folds.\n",
    "\n",
    "    Reads the tab-separated file at the module-level `data_file` (columns\n",
    "    `text` and `label`), shuffles the rows with `np.random`, deals the rows\n",
    "    of each label across the folds as evenly as possible, then moves surplus\n",
    "    rows between folds so every fold holds exactly `total // fold_num` rows.\n",
    "    Rows left over when the row count is not a multiple of `fold_num` are\n",
    "    dropped.\n",
    "\n",
    "    Args:\n",
    "        fold_num: Number of folds to build.\n",
    "        num: Maximum number of leading rows of the file to use.\n",
    "\n",
    "    Returns:\n",
    "        A list of `fold_num` dicts, each `{'label': [...], 'text': [...]}`\n",
    "        holding one fold with labels and texts shuffled in unison.\n",
    "    \"\"\"\n",
    "    f = pd.read_csv(data_file, sep='\\t', encoding='UTF-8')\n",
    "    texts = f['text'].tolist()[:num]\n",
    "    labels = f['label'].tolist()[:num]\n",
    "\n",
    "    total = len(labels)\n",
    "\n",
    "    # Shuffle texts and labels together through one shared permutation.\n",
    "    index = list(range(total))\n",
    "    np.random.shuffle(index)\n",
    "    all_texts = [texts[i] for i in index]\n",
    "    all_labels = [labels[i] for i in index]\n",
    "\n",
    "    # Group the shuffled row positions by (stringified) label.\n",
    "    label2id = {}\n",
    "    for i, label in enumerate(all_labels):\n",
    "        label2id.setdefault(str(label), []).append(i)\n",
    "\n",
    "    # Deal each label's rows across the folds; the first `other` folds get one extra.\n",
    "    all_index = [[] for _ in range(fold_num)]\n",
    "    for data in label2id.values():\n",
    "        batch_size, other = divmod(len(data), fold_num)\n",
    "        offset = 0\n",
    "        for i in range(fold_num):\n",
    "            cur_batch_size = batch_size + 1 if i < other else batch_size\n",
    "            # A running offset keeps the slices disjoint. Starting every slice at\n",
    "            # i * batch_size made neighbouring folds share a row (and left the\n",
    "            # tail rows unused) whenever len(data) was not divisible by fold_num.\n",
    "            all_index[i].extend(data[offset:offset + cur_batch_size])\n",
    "            offset += cur_batch_size\n",
    "\n",
    "    # Rebalance: trim oversized folds into a spare pool, top up undersized ones from it.\n",
    "    fold_size = total // fold_num\n",
    "    other_texts = []\n",
    "    other_labels = []\n",
    "    start = 0  # next unread position in the spare pool\n",
    "    fold_data = []\n",
    "    for fold in range(fold_num):\n",
    "        fold_texts = [all_texts[i] for i in all_index[fold]]\n",
    "        fold_labels = [all_labels[i] for i in all_index[fold]]\n",
    "        cur_size = len(fold_labels)\n",
    "\n",
    "        if cur_size > fold_size:\n",
    "            other_texts.extend(fold_texts[fold_size:])\n",
    "            other_labels.extend(fold_labels[fold_size:])\n",
    "            fold_texts = fold_texts[:fold_size]\n",
    "            fold_labels = fold_labels[:fold_size]\n",
    "        elif cur_size < fold_size:\n",
    "            end = start + fold_size - cur_size\n",
    "            fold_texts = fold_texts + other_texts[start:end]\n",
    "            fold_labels = fold_labels + other_labels[start:end]\n",
    "            start = end\n",
    "\n",
    "        assert fold_size == len(fold_labels)\n",
    "\n",
    "        # Shuffle inside the fold so rows are no longer grouped by label.\n",
    "        index = list(range(fold_size))\n",
    "        np.random.shuffle(index)\n",
    "        fold_data.append({\n",
    "            'label': [fold_labels[i] for i in index],\n",
    "            'text': [fold_texts[i] for i in index],\n",
    "        })\n",
    "\n",
    "    logging.info(\"Fold lens %s\", str([len(data['label']) for data in fold_data]))\n",
    "\n",
    "    return fold_data\n",
    "\n",
    "\n",
    "# Build the folds using the configured fold count instead of repeating the literal\n",
    "fold_data = all_data2fold(fold_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-07-18 23:30:04,929 INFO: Total 9000 docs.\n"
     ]
    }
   ],
   "source": [
    "# Collect the word2vec training corpus: the texts of every fold before\n",
    "# index `fold_id` (the fold at `fold_id` itself is left out).\n",
    "fold_id = 9\n",
    "\n",
    "train_texts = [text for k in range(fold_id) for text in fold_data[k]['text']]\n",
    "\n",
    "logging.info('Total %d docs.', len(train_texts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-07-18 23:30:04,938 INFO: Start training...\n",
      "2020-07-18 23:30:05,545 INFO: collecting all words and their counts\n",
      "2020-07-18 23:30:05,546 INFO: PROGRESS: at sentence #0, processed 0 words, keeping 0 word types\n",
      "2020-07-18 23:30:06,231 INFO: collected 5295 word types from a corpus of 8191447 raw words and 9000 sentences\n",
      "2020-07-18 23:30:06,231 INFO: Loading a fresh vocabulary\n",
      "2020-07-18 23:30:06,305 INFO: effective_min_count=5 retains 4335 unique words (81% of original 5295, drops 960)\n",
      "2020-07-18 23:30:06,305 INFO: effective_min_count=5 leaves 8189498 word corpus (99% of original 8191447, drops 1949)\n",
      "2020-07-18 23:30:06,314 INFO: deleting the raw counts dictionary of 5295 items\n",
      "2020-07-18 23:30:06,316 INFO: sample=0.001 downsamples 61 most-common words\n",
      "2020-07-18 23:30:06,316 INFO: downsampling leaves estimated 7070438 word corpus (86.3% of prior 8189498)\n",
      "2020-07-18 23:30:06,324 INFO: estimated required memory for 4335 words and 100 dimensions: 5635500 bytes\n",
      "2020-07-18 23:30:06,325 INFO: resetting layer weights\n",
      "2020-07-18 23:30:06,356 INFO: training model with 8 workers on 4335 vocabulary and 100 features, using sg=0 hs=0 sample=0.001 negative=5 window=5\n",
      "2020-07-18 23:30:07,360 INFO: EPOCH 1 - PROGRESS: at 63.20% examples, 4434412 words/s, in_qsize 15, out_qsize 0\n",
      "2020-07-18 23:30:07,959 INFO: worker thread finished; awaiting finish of 7 more threads\n",
      "2020-07-18 23:30:07,960 INFO: worker thread finished; awaiting finish of 6 more threads\n",
      "2020-07-18 23:30:07,963 INFO: worker thread finished; awaiting finish of 5 more threads\n",
      "2020-07-18 23:30:07,963 INFO: worker thread finished; awaiting finish of 4 more threads\n",
      "2020-07-18 23:30:07,964 INFO: worker thread finished; awaiting finish of 3 more threads\n",
      "2020-07-18 23:30:07,965 INFO: worker thread finished; awaiting finish of 2 more threads\n",
      "2020-07-18 23:30:07,968 INFO: worker thread finished; awaiting finish of 1 more threads\n",
      "2020-07-18 23:30:07,969 INFO: worker thread finished; awaiting finish of 0 more threads\n",
      "2020-07-18 23:30:07,970 INFO: EPOCH - 1 : training on 8191447 raw words (7021120 effective words) took 1.6s, 4357567 effective words/s\n",
      "2020-07-18 23:30:08,979 INFO: EPOCH 2 - PROGRESS: at 59.39% examples, 4143643 words/s, in_qsize 15, out_qsize 0\n",
      "2020-07-18 23:30:09,661 INFO: worker thread finished; awaiting finish of 7 more threads\n",
      "2020-07-18 23:30:09,663 INFO: worker thread finished; awaiting finish of 6 more threads\n",
      "2020-07-18 23:30:09,663 INFO: worker thread finished; awaiting finish of 5 more threads\n",
      "2020-07-18 23:30:09,664 INFO: worker thread finished; awaiting finish of 4 more threads\n",
      "2020-07-18 23:30:09,667 INFO: worker thread finished; awaiting finish of 3 more threads\n",
      "2020-07-18 23:30:09,667 INFO: worker thread finished; awaiting finish of 2 more threads\n",
      "2020-07-18 23:30:09,670 INFO: worker thread finished; awaiting finish of 1 more threads\n",
      "2020-07-18 23:30:09,672 INFO: worker thread finished; awaiting finish of 0 more threads\n",
      "2020-07-18 23:30:09,672 INFO: EPOCH - 2 : training on 8191447 raw words (7021506 effective words) took 1.7s, 4144060 effective words/s\n",
      "2020-07-18 23:30:10,681 INFO: EPOCH 3 - PROGRESS: at 59.52% examples, 4154672 words/s, in_qsize 15, out_qsize 0\n",
      "2020-07-18 23:30:11,356 INFO: worker thread finished; awaiting finish of 7 more threads\n",
      "2020-07-18 23:30:11,356 INFO: worker thread finished; awaiting finish of 6 more threads\n",
      "2020-07-18 23:30:11,358 INFO: worker thread finished; awaiting finish of 5 more threads\n",
      "2020-07-18 23:30:11,359 INFO: worker thread finished; awaiting finish of 4 more threads\n",
      "2020-07-18 23:30:11,362 INFO: worker thread finished; awaiting finish of 3 more threads\n",
      "2020-07-18 23:30:11,362 INFO: worker thread finished; awaiting finish of 2 more threads\n",
      "2020-07-18 23:30:11,365 INFO: worker thread finished; awaiting finish of 1 more threads\n",
      "2020-07-18 23:30:11,366 INFO: worker thread finished; awaiting finish of 0 more threads\n",
      "2020-07-18 23:30:11,367 INFO: EPOCH - 3 : training on 8191447 raw words (7020706 effective words) took 1.7s, 4163417 effective words/s\n",
      "2020-07-18 23:30:12,378 INFO: EPOCH 4 - PROGRESS: at 58.80% examples, 4102329 words/s, in_qsize 15, out_qsize 0\n",
      "2020-07-18 23:30:13,072 INFO: worker thread finished; awaiting finish of 7 more threads\n",
      "2020-07-18 23:30:13,078 INFO: worker thread finished; awaiting finish of 6 more threads\n",
      "2020-07-18 23:30:13,079 INFO: worker thread finished; awaiting finish of 5 more threads\n",
      "2020-07-18 23:30:13,079 INFO: worker thread finished; awaiting finish of 4 more threads\n",
      "2020-07-18 23:30:13,080 INFO: worker thread finished; awaiting finish of 3 more threads\n",
      "2020-07-18 23:30:13,080 INFO: worker thread finished; awaiting finish of 2 more threads\n",
      "2020-07-18 23:30:13,081 INFO: worker thread finished; awaiting finish of 1 more threads\n",
      "2020-07-18 23:30:13,082 INFO: worker thread finished; awaiting finish of 0 more threads\n",
      "2020-07-18 23:30:13,083 INFO: EPOCH - 4 : training on 8191447 raw words (7021984 effective words) took 1.7s, 4117851 effective words/s\n",
      "2020-07-18 23:30:14,091 INFO: EPOCH 5 - PROGRESS: at 58.99% examples, 4115963 words/s, in_qsize 16, out_qsize 0\n",
      "2020-07-18 23:30:14,769 INFO: worker thread finished; awaiting finish of 7 more threads\n",
      "2020-07-18 23:30:14,770 INFO: worker thread finished; awaiting finish of 6 more threads\n",
      "2020-07-18 23:30:14,770 INFO: worker thread finished; awaiting finish of 5 more threads\n",
      "2020-07-18 23:30:14,771 INFO: worker thread finished; awaiting finish of 4 more threads\n",
      "2020-07-18 23:30:14,773 INFO: worker thread finished; awaiting finish of 3 more threads\n",
      "2020-07-18 23:30:14,776 INFO: worker thread finished; awaiting finish of 2 more threads\n",
      "2020-07-18 23:30:14,777 INFO: worker thread finished; awaiting finish of 1 more threads\n",
      "2020-07-18 23:30:14,779 INFO: worker thread finished; awaiting finish of 0 more threads\n",
      "2020-07-18 23:30:14,779 INFO: EPOCH - 5 : training on 8191447 raw words (7021532 effective words) took 1.7s, 4156171 effective words/s\n",
      "2020-07-18 23:30:14,780 INFO: training on a 40957235 raw words (35106848 effective words) took 8.4s, 4167675 effective words/s\n",
      "2020-07-18 23:30:14,780 INFO: precomputing L2-norms of word weight vectors\n",
      "2020-07-18 23:30:14,782 INFO: saving Word2Vec object under ./word2vec.bin, separately None\n",
      "2020-07-18 23:30:14,783 INFO: not storing attribute vectors_norm\n",
      "2020-07-18 23:30:14,783 INFO: not storing attribute cum_table\n",
      "2020-07-18 23:30:14,820 INFO: saved ./word2vec.bin\n"
     ]
    }
   ],
   "source": [
    "logging.info('Start training...')\n",
    "from gensim.models.word2vec import Word2Vec\n",
    "\n",
    "num_features = 100     # Word vector dimensionality\n",
    "num_workers = 8       # Number of threads to run in parallel\n",
    "\n",
    "# Tokenise into a new name: `train_texts` keeps holding the raw strings, so\n",
    "# re-running this cell no longer fails on `.split()` of an already-split list.\n",
    "train_sentences = [text.split() for text in train_texts]\n",
    "# NOTE(review): `size=` and `init_sims` look like the gensim 3.x API; newer\n",
    "# releases are believed to use `vector_size=` - confirm the pinned version.\n",
    "model = Word2Vec(train_sentences, workers=num_workers, size=num_features)\n",
    "# Precompute L2 norms and, with replace=True, keep only the normalised vectors.\n",
    "model.init_sims(replace=True)\n",
    "\n",
    "# save model\n",
    "model.save(\"./word2vec.bin\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-07-18 23:30:14,825 INFO: loading Word2Vec object from ./word2vec.bin\n",
      "2020-07-18 23:30:14,958 INFO: loading wv recursively from ./word2vec.bin.wv.* with mmap=None\n",
      "2020-07-18 23:30:14,958 INFO: setting ignored attribute vectors_norm to None\n",
      "2020-07-18 23:30:14,959 INFO: loading vocabulary recursively from ./word2vec.bin.vocabulary.* with mmap=None\n",
      "2020-07-18 23:30:14,959 INFO: loading trainables recursively from ./word2vec.bin.trainables.* with mmap=None\n",
      "2020-07-18 23:30:14,959 INFO: setting ignored attribute cum_table to None\n",
      "2020-07-18 23:30:14,959 INFO: loaded ./word2vec.bin\n",
      "2020-07-18 23:30:14,965 INFO: storing 4335x100 projection weights into ./word2vec.txt\n"
     ]
    }
   ],
   "source": [
    "# Reload the model from the path the training cell saved it to\n",
    "model = Word2Vec.load(\"./word2vec.bin\")\n",
    "\n",
    "# Export the word vectors in word2vec format as plain text (binary=False)\n",
    "model.wv.save_word2vec_format('./word2vec.txt', binary=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**关于Datawhale：**\n",
    "\n",
    "> Datawhale是一个专注于数据科学与AI领域的开源组织，汇集了众多领域院校和知名企业的优秀学习者，聚合了一群有开源精神和探索精神的团队成员。Datawhale 以“for the learner，和学习者一起成长”为愿景，鼓励真实地展现自我、开放包容、互信互助、敢于试错和勇于担当。同时 Datawhale 用开源的理念去探索开源内容、开源学习和开源方案，赋能人才培养，助力人才成长，建立起人与人，人与知识，人与企业和人与未来的联结。\n",
    "\n",
    "本次新闻文本分类学习，专题知识将在天池分享，详情可关注Datawhale：\n",
    "\n",
    " ![](http://jupter-oss.oss-cn-hangzhou.aliyuncs.com/public/files/image/1095279172547/1584432602983_kAxAvgQpG2.jpg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
